import torch
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torchvision
from PIL import Image
from xml.dom.minidom import parse
import utils
import transforms as T
from engine import train_one_epoch, evaluate
import xml.etree.cElementTree as ET
import collections
import pandas as pd
from torchvision.transforms import functional
import random
from aip import AipFace
import base64
import time

# Baidu AIP face-service credentials: APP_ID, API key (AK) and secret key (SK).
# SECURITY: these values were committed in plain text. They can now be supplied
# through the BAIDU_APP_ID / BAIDU_API_KEY / BAIDU_SECRET_KEY environment
# variables. The literal fallbacks are kept only so existing setups keep working;
# they should be rotated and removed from source control.
APP_ID = os.environ.get('BAIDU_APP_ID', '22912073')
API_KEY = os.environ.get('BAIDU_API_KEY', '3zbPMNiqWOrsD5BspgX1pBoR')
SECRET_KEY = os.environ.get('BAIDU_SECRET_KEY', 'nxSoiCZnSdOV9DVTPhHrlMvAYvTHZNa4')
# Shared AipFace client, used by predict_faces() by default.
aipFace = AipFace(APP_ID, API_KEY, SECRET_KEY)

# Maps the stem of a Baidu face-search `user_info` value (the text before the
# first '.') to the display name drawn on the image.
faces_name_dict = dict(
    Paulmale='Paul Male',
    Danielmale='Daniel Male',
    Fishermale='Fisher Male',
    Jackmale='Jack Male',
    Kevinmale='Kevin Male',
    lilyfemale='lily Female',
    Rosefemale='Rose Female',
    Maryfemale='Mary Female',
    Michaelmale='Michael Male',
    stevenmale='Steven Male',
    Jamesmale='James Male',
)


def predict_faces(img_path, client=None, name_map=None):
    """Identify known faces in an image with Baidu AIP ``multiSearch``.

    The image file is base64-encoded and searched against the ``'main'``
    user group. For each returned face only the first (best) matching user is
    kept (``max_user_num`` is 1).

    Args:
        img_path: path of the image file to send.
        client: object exposing ``multiSearch(image, image_type, group_id_list,
            options)``. Defaults to the module-level ``aipFace`` client.
        name_map: mapping from the ``user_info`` stem (the text before the first
            ``'.'``) to a display name. Defaults to ``faces_name_dict``.
            Stems not in the mapping are displayed unchanged.

    Returns:
        A list of dicts with keys ``'location'`` (the API's location dict),
        ``'user_id'`` (display name) and ``'score'``. The list is empty when the
        picture has no face, when the API reports an error, or when no face
        matched a user.
    """
    if client is None:
        client = aipFace
    if name_map is None:
        name_map = faces_name_dict

    # Read the image and encode it as base64 text for imageType "BASE64".
    with open(img_path, "rb") as fp:
        image = base64.b64encode(fp.read()).decode('utf-8')

    options = {
        "max_face_num": 10,
        "match_threshold": 10,
        "quality_control": "NONE",
        "liveness_control": "NONE",
        "max_user_num": 1,
    }

    # Face search (M:N) with the options above.
    response = client.multiSearch(image, "BASE64", 'main', options)
    if response.get('error_msg') == 'pic not has face':
        return []
    print(response)
    if response.get('error_code', 0) != 0 or not response.get('result'):
        # Any other API error (e.g. a QPS limit) previously fell through to
        # response['result'] and crashed. Report it and draw no faces instead.
        print('face search failed:', response.get('error_code'), response.get('error_msg'))
        return []

    result = response['result']
    print(result)

    faces = []
    for face in result.get('face_list') or []:
        user_list = face.get('user_list') or []
        if not user_list:
            # A face was detected but no user in the group matched it.
            continue
        best = user_list[0]
        stem = best['user_info'].split('.')[0]
        faces.append({
            'location': face['location'],
            'user_id': name_map.get(stem, stem),
            'score': best['score'],
        })
    return faces


def predict_once(model, img):
    """Run ``model`` once on a single image and return its raw output.

    The model is put into eval mode. The image is converted with
    ``functional.to_tensor`` and moved to CUDA when available, otherwise CPU.
    Inference runs under ``torch.no_grad()``. The prediction is printed before
    being returned.

    Example output noted by the original author for a detection model:
        [{'boxes': tensor([[1492.6672,  238.4670, 1765.5385,  315.0320],
                           [ 887.1390,  256.8106, 1154.6687,  330.2953]], device='cuda:0'),
          'labels': tensor([1, 1], device='cuda:0'),
          'scores': tensor([1.0000, 1.0000], device='cuda:0')}]
    """
    model.eval()

    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = [functional.to_tensor(img).to(target_device)]

    with torch.no_grad():
        prediction = model(batch)

    # Dump the raw result for inspection.
    print(prediction)
    return prediction


def random_color():
    """Return a random ``(b, g, r)`` tuple with each channel drawn from [50, 200]."""
    # One randint per channel, drawn in b, g, r order.
    return tuple(random.randint(50, 200) for _channel in 'bgr')

def _load_label_list(root):
    """Return the non-empty, right-stripped lines of ``<root>/label_list.txt``."""
    with open(os.path.join(root, "label_list.txt"), 'r') as file:
        return [line.rstrip() for line in file if line.rstrip() != '']


def _rename_test_images(img_root):
    """Rename every regular file in ``img_root`` to ``<n>_<unix time>.jpg``.

    NOTE: destructive. The original file names are lost, and the extension is
    forced to ``.jpg`` regardless of the content. Subdirectories are left alone.
    """
    for count, img_name in enumerate(os.listdir(img_root), start=1):
        src = os.path.join(img_root, img_name)
        if not os.path.isfile(src):
            continue
        dst = os.path.join(img_root, str(count) + '_' + str(int(time.time())) + '.jpg')
        os.rename(src, dst)


def _best_detection_per_label(scores, labels, score_threshold):
    """Return the index of the highest-scoring detection for each label, best first.

    Detections scoring below ``score_threshold`` are ignored. If a label has
    several detections with the same top score, the earliest one wins.
    """
    best = {}  # label name -> (score, detection index)
    for idx, (label, score) in enumerate(zip(labels, scores)):
        score = float(score)
        if score < score_threshold:
            continue
        if label not in best or score > best[label][0]:
            best[label] = (score, idx)
    ranked = sorted(best.values(), key=lambda item: item[0], reverse=True)
    return [idx for _, idx in ranked]


def _draw_detection(img, box, label, score, color):
    """Draw one detector box, its label at the top-left and its score at the bottom-left."""
    x1, y1, x2, y2 = (float(v) for v in box[:4])
    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color=color, thickness=2)
    # Filled 16-px strip above the box as a background for the label text.
    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y1 - 16)), color=color, thickness=-1)
    cv2.putText(img, label, (int(x1), int(y1)), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5,
                color=(0, 0, 0))
    cv2.putText(img, str(round(score, 2)), (int(x1), int(y2)), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                fontScale=1, color=color, thickness=2)


def _draw_face(img, face, color):
    """Draw one face-search result (as returned by predict_faces) onto ``img``."""
    location = face['location']
    left, top = location['left'], location['top']
    width, height = location['width'], location['height']
    cv2.rectangle(img, (int(left), int(top)), (int(left + width), int(top + height)),
                  color=color, thickness=1)
    # Filled 16-px strip above the face as a background for the name.
    cv2.rectangle(img, (int(left), int(top)), (int(left + width), int(top - 16)), color=color,
                  thickness=-1)
    cv2.putText(img, face['user_id'], (int(left), int(top)), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                fontScale=0.5, color=(0, 0, 0), thickness=1)
    cv2.putText(img, str(round(face['score'], 2)), (int(left), int(top + height)),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=color, thickness=2)


def main(args):
    """Annotate every image in ``dataset/test`` with detector boxes and face-search names.

    Steps:
        1. Load class names from ``dataset/label_list.txt``.
        2. Load the pickled model ``dataset/models/<model_selected>.pkl``.
        3. Rename all files in ``dataset/test``.
        4. For each image, draw the best detection per label (at most
           ``max_boxes``, best first, score >= ``score_threshold``) and the
           Baidu face-search results.
        5. Show the result and save it to ``dataset/prediction``.

    Press ESC in the display window to stop early.

    Args:
        args: parsed command-line namespace; currently unused.
    """
    root = r'dataset'
    score_threshold = 0.5
    model_selected = '32'  # alternative: 'last_model'
    max_boxes = 15
    # Minimum seconds padded around each iteration; presumably to respect the
    # Baidu API rate limit. TODO confirm.
    min_interval = 0.5

    # Class names; the model's label ids index into this list.
    label_list = _load_label_list(root)
    num_classes = len(label_list)
    print(label_list)
    print(num_classes)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model_path = os.path.join(root, 'models', model_selected + '.pkl')
    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted files.
    # map_location lets a GPU-saved model load on a CPU-only machine.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()

    img_root = os.path.join(root, 'test')
    _rename_test_images(img_root)

    # cv2.imwrite fails silently if the target directory does not exist.
    pred_dir = os.path.join(root, 'prediction')
    os.makedirs(pred_dir, exist_ok=True)

    for img_name in os.listdir(img_root):
        time_start = time.time()
        print(time_start)

        img_path = os.path.join(img_root, img_name)
        src_img = cv2.imread(img_path)
        if src_img is None:
            # Not a readable image (or a directory); cvtColor would crash on None.
            print('skipping unreadable image:', img_path)
            continue
        img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)

        prediction = predict_once(model=model, img=img)
        boxes = prediction[0]['boxes']
        scores = prediction[0]['scores']
        print(prediction[0]['labels'])
        print(scores)
        labels = [label_list[label] for label in prediction[0]['labels']]

        # One box per label, highest score first, capped at max_boxes.
        selected = _best_detection_per_label(scores, labels, score_threshold)
        print([(labels[idx], float(scores[idx])) for idx in selected])
        for idx in selected[:max_boxes]:
            _draw_detection(src_img, boxes[idx], labels[idx], float(scores[idx]), random_color())

        # Face recognition through the Baidu API.
        for face in predict_faces(img_path):
            _draw_face(src_img, face, random_color())

        # Pad the iteration so it lasts at least min_interval before displaying.
        # (Gender annotation via aipFace.detect used to follow here; it was
        # disabled in a dead string literal and has been removed.)
        time_mid = time.time()
        if time_mid - time_start < min_interval:
            time.sleep(min_interval - (time_mid - time_start))

        cv2.imshow('result', src_img)
        out_path = os.path.join(pred_dir, img_name)
        print(out_path)
        if not cv2.imwrite(out_path, src_img, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
            print('failed to write', out_path)

        # Wait for any key; ESC (27) stops processing.
        if cv2.waitKey() == 27:
            break

        time_end = time.time()
        if time_end - time_mid < min_interval:
            time.sleep(min_interval - (time_end - time_mid))

    cv2.destroyAllWindows()


if __name__ == '__main__':
    import argparse

    # Only two options are defined; main() does not currently read either of
    # them (they are left over from a training script).
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument('--epochs', default=26, type=int, metavar='N',
                     help='number of total epochs to run')
    cli.add_argument('--print-freq', default=20, type=int, help='print frequency')

    # Pass the parsed command-line namespace to the entry point.
    main(cli.parse_args())
